package cn.doitedu.ml.tfidf

import cn.doitedu.commons.util.SparkUtil
import org.apache.log4j.{Level, Logger}

/**
 * @date: 2020/2/22
 * @site: www.doitedu.cn
 * @author: hunter.d 涛哥
 * @qq: 657270652
 * @description:
 */
object TFIDF_SQL {

  def main(args: Array[String]): Unit = {

    Logger.getLogger("org").setLevel(Level.WARN)

    val spark = SparkUtil.getSparkSession(this.getClass.getSimpleName)
    import spark.implicits._


    /**
      * docid,doc
      * 1,    a a a a a a x x y  ->[0,0,0,0,6,0,2,0,0,1,0,0]
      * 2,    b b b x y          ->
      * 3,    c c x y            ->
      * 4,    d x                ->
      */
    val df = spark.read.option("header","true").csv("userprofile/data/demo/tfidf/docs.txt")

    // 将定义好的udf注册到sql解析引擎
    import cn.doitedu.ml.util.VecUtil._
    spark.udf.register("doc2tf",doc2tf)

    // 1. 将原始数据，加工成tf值 词向量
    val tf_df = df.selectExpr("docid","doc2tf(doc,26) as tfarr")
    println("tf特征向量结果： ----------")
    tf_df.show(100,false)

    // 2.利用tf_df，计算出每个词所出现的文档数
    // 将tf_df做一个变换： tf数组中的非零值全部替换成1，以便于后续的计数操作
    spark.udf.register("arr2one",arr2One)
    val tf_df_one = tf_df.selectExpr("docid","arr2one(tfarr) as flag")

    tf_df_one.show(10,false)

    // 接下来将tf_df_one这个dataframe中的所有行的flag数组，对应位置的元素累加到一起==》 该位置的词所出现过的文档数
    import cn.doitedu.ml.util.ArraySumUDAF
    spark.udf.register("arr_sum",ArraySumUDAF)
    val docCntDF = tf_df_one.selectExpr("arr_sum(flag) as doc_cnt")
    docCntDF.show(10,false)

    // 接着，将上面的结果：每个词所出现的文档数 ==》 IDF:  lg(文档总数/(1+词文档数))
    val docTotal: Long = df.count() // 文档总数
    spark.udf.register("cnt2idf",docCntArr2Idf)


    //val idfDF = docCntDF.selectExpr("cnt2idf(doc_cnt,"+docTotal+")")
    val idfDF = docCntDF.selectExpr("cnt2idf(doc_cnt,"+docTotal+")")  // cnt2idf(doc_cnt,4)

    println("idf特征向量结果： ----------")
    idfDF.show(10,false)

    // TODO 将idfDF这个表  和  tfDF表，综合相乘得到  tfidf表

    spark.close()
  }



}
